# Copyright (c) OpenMMLab. All rights reserved.
import os

import torch
import torch.nn as nn
from mmengine.logging import MMLogger

from mmagic.utils import try_import

transformers = try_import('transformers')

if transformers is not None:
    from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
    from transformers.models.clip.feature_extraction_clip import \
        CLIPFeatureExtractor  # noqa
    from transformers.models.clip.modeling_clip import CLIPTextModel
    from transformers.models.clip.tokenization_clip import CLIPTokenizer

    logger = MMLogger.get_current_instance()

    def cosine_distance(image_embeds, text_embeds):
        """Compute pairwise cosine similarity of two sets of embeddings.

        Each row of both inputs is L2-normalised, then the two matrices are
        multiplied. Entry ``[i, j]`` of the result is therefore the cosine
        similarity between ``image_embeds[i]`` and ``text_embeds[j]``.
        Despite the name, the value is a similarity, not ``1 - similarity``.

        Args:
            image_embeds (torch.Tensor): 2-D tensor of shape (N, D).
            text_embeds (torch.Tensor): 2-D tensor of shape (M, D).

        Returns:
            torch.Tensor: similarity matrix of shape (N, M).
        """
        # p=2 / dim=1 are the defaults of ``normalize``; spelled out for
        # clarity: each row becomes a unit vector.
        image_unit = nn.functional.normalize(image_embeds, p=2.0, dim=1)
        text_unit = nn.functional.normalize(text_embeds, p=2.0, dim=1)
        text_columns = text_unit.t()
        return torch.mm(image_unit, text_columns)

    class StableDiffusionSafetyChecker(PreTrainedModel):
        """CLIP-based NSFW filter for stable diffusion outputs.

        A CLIP vision tower followed by a linear projection embeds each
        image. Each embedding is compared, by cosine similarity, against 17
        concept embeddings and 3 "special care" embeddings, and every
        similarity is thresholded by a per-embedding weight. Images that
        exceed any concept threshold are replaced by an all-zero (black)
        image.
        """
        config_class = CLIPConfig
        _no_split_modules = ['CLIPEncoderLayer']

        def __init__(self, config: CLIPConfig):
            """Check result images of stable diffusion to prevent generated
            NSFW content.

            Args:
                config (CLIPConfig): config for transformers clip.
            """

            super().__init__(config)

            self.vision_model = CLIPVisionModel(config.vision_config)
            self.visual_projection = nn.Linear(
                config.vision_config.hidden_size,
                config.projection_dim,
                bias=False)

            # Frozen embeddings and thresholds are initialised to ones here.
            # They are presumably overwritten by the checkpoint loaded via
            # ``from_pretrained`` (TODO confirm).
            self.concept_embeds = nn.Parameter(
                torch.ones(17, config.projection_dim), requires_grad=False)
            self.special_care_embeds = nn.Parameter(
                torch.ones(3, config.projection_dim), requires_grad=False)

            self.concept_embeds_weights = nn.Parameter(
                torch.ones(17), requires_grad=False)
            self.special_care_embeds_weights = nn.Parameter(
                torch.ones(3), requires_grad=False)

        @torch.no_grad()
        def forward(self, clip_input, images):
            """Return black images for inputs flagged as NSFW.

            Args:
                clip_input (torch.Tensor):
                    image feature extracted by clip feature extractor.
                images (torch.Tensor):
                    images generated by stable diffusion. Flagged entries
                    are overwritten in place via ``images[idx] = ...``.

            Returns:
                images (torch.Tensor):
                    black images where nsfw content was detected,
                    otherwise the input images.
                has_nsfw_concepts (list[bool]):
                    per-image flags indicating nsfw content.
            """
            pooled_output = self.vision_model(clip_input)[1]
            image_embeds = self.visual_projection(pooled_output)

            # we always cast to float32 as this does not cause
            # significant overhead and is compatible with bfloat16
            special_cos_dist = cosine_distance(
                image_embeds, self.special_care_embeds).cpu().float().numpy()
            cos_dist = cosine_distance(
                image_embeds, self.concept_embeds).cpu().float().numpy()

            result = []
            batch_size = image_embeds.shape[0]
            for i in range(batch_size):
                result_img = {
                    'special_scores': {},
                    'special_care': [],
                    'concept_scores': {},
                    'bad_concepts': []
                }

                # increase this value to create a stronger `nsfw` filter
                # at the cost of increasing the possibility of
                # filtering benign images
                adjustment = 0.0

                for concept_idx in range(len(special_cos_dist[0])):
                    concept_cos = special_cos_dist[i][concept_idx]
                    concept_threshold = self.special_care_embeds_weights[
                        concept_idx].item()
                    score = round(
                        concept_cos - concept_threshold + adjustment, 3)
                    result_img['special_scores'][concept_idx] = score
                    if score > 0:
                        # Store an ordered (index, score) pair. A set literal
                        # here would lose the order and merge equal values.
                        result_img['special_care'].append(
                            (concept_idx, score))
                        # Once a special-care concept matches, the remaining
                        # special-care checks and all regular concept checks
                        # for this image become stricter.
                        adjustment = 0.01

                for concept_idx in range(len(cos_dist[0])):
                    concept_cos = cos_dist[i][concept_idx]
                    concept_threshold = self.concept_embeds_weights[
                        concept_idx].item()
                    score = round(
                        concept_cos - concept_threshold + adjustment, 3)
                    result_img['concept_scores'][concept_idx] = score
                    if score > 0:
                        result_img['bad_concepts'].append(concept_idx)

                result.append(result_img)

            has_nsfw_concepts = [
                len(res['bad_concepts']) > 0 for res in result
            ]

            for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
                if not has_nsfw_concept:
                    continue
                image = images[idx]
                if torch.is_tensor(image):
                    # Keep the dtype/device of the input so CUDA or
                    # half-precision images do not receive a freshly
                    # allocated CPU float32 tensor.
                    images[idx] = torch.zeros_like(image)  # black image
                else:
                    images[idx] = torch.zeros(image.shape)  # black image

            if any(has_nsfw_concepts):
                # Look the logger up at call time rather than at import time,
                # so the warning goes to the logger that is current when
                # inference actually runs.
                MMLogger.get_current_instance().warning(
                    'NSFW content was detected in one or more images.'
                    ' A black image will be returned instead.'
                    ' Try again with a different prompt and/or seed.')

            return images, has_nsfw_concepts

    def load_clip_submodels(init_cfg, submodels, requires_safety_checker):
        """Load the CLIP-based submodels used by stable diffusion.

        Nothing is loaded when ``init_cfg`` has no truthy
        ``'pretrained_model_path'``. Otherwise each submodel is read from the
        sub-directory of that path named after it.

        Args:
            init_cfg (dict):
                ckpt config; ``init_cfg['pretrained_model_path']`` is the
                root directory of the clip model checkpoints.
            submodels (List):
                list of stable diffusion submodels. ``'safety_checker'`` is
                appended to it in place when the safety checker is loaded.
            requires_safety_checker (bool):
                whether to load safety checker.

        Returns:
            tokenizer(CLIPTokenizer):
                tokenizer with ckpt loaded, or None.
            feature_extractor(CLIPFeatureExtractor):
                feature_extractor with ckpt loaded, or None.
            text_encoder(CLIPTextModel):
                text_encoder with ckpt loaded, or None.
            safety_checker(StableDiffusionSafetyChecker):
                safety_checker with ckpt loaded, or None.
        """
        model_root = init_cfg.get('pretrained_model_path', None)
        if not model_root:
            return None, None, None, None

        def _sub_path(name):
            # checkpoint directory of one submodel under the model root
            return os.path.join(model_root, name)

        tokenizer = CLIPTokenizer.from_pretrained(_sub_path('tokenizer'))
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            _sub_path('feature_extractor'))
        text_encoder = CLIPTextModel.from_pretrained(
            _sub_path('text_encoder'))

        safety_checker = None
        if requires_safety_checker:
            submodels.append('safety_checker')
            safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                _sub_path('safety_checker'))

        return tokenizer, feature_extractor, text_encoder, safety_checker

else:

    def load_clip_submodels(init_cfg, submodels, requires_safety_checker):
        """Fallback used when ``transformers`` could not be imported.

        It keeps the same signature as the real loader, but always raises.

        Raises:
            ImportError: always, asking the user to install transformers.
        """
        message = ('Please install transformers via '
                   '\'pip install transformers\'')
        raise ImportError(message)
